{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "eeg_mac.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "Yy0WHpW3X9oX",
        "9SwIjOX9frk6",
        "ZjxFdZTXBZWb",
        "aBusNmsKpc2Z",
        "NwwLDouezDLc",
        "VpMowosc0Bzf"
      ],
      "toc_visible": true,
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/vitaldb/examples/blob/master/eeg_mac.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yy0WHpW3X9oX"
      },
      "source": [
        "# 뇌파로부터 마취제 농도 예측 인공지능 모델 실습\n",
        "Sevoflurane 마취 중 뇌파로부터 마취제 농도(age related MAC) 예측 모델\n",
        "\n",
        "## VitalDB 데이터 셋 이용\n",
        "본 예제에서는 오픈 생체 신호 데이터셋인 VitalDB를 이용하는 모든 사용자는 반드시 아래 Data Use Agreement에 동의하여야 합니다.\n",
        "\n",
        "https://vitaldb.net/data-bank/?query=guide&documentId=13qqajnNZzkN7NZ9aXnaQ-47NWy7kx-a6gbrcEsi-gak&sectionId=h.usmoena3l4rb\n",
        "\n",
        "동의하지 않을 경우 이 창을 닫으세요."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9SwIjOX9frk6"
      },
      "source": [
        "## 본 프로그램에서 이용할 라이브러리 설치 및 import"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ch6czkFZfw_G",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "d563f693-f8c5-46bb-9c02-e9f207657a43"
      },
      "source": [
        "!pip install vitaldb\n",
        "\n",
        "import vitaldb\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import os\n",
        "import scipy.signal\n",
        "import matplotlib.pyplot as plt\n",
        "import random\n",
        "import itertools as it\n",
        "import numpy as np\n",
        "from matplotlib import pyplot as plt"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting vitaldb\n",
            "  Downloading vitaldb-0.0.11-py3-none-any.whl (42 kB)\n",
            "\u001b[?25l\r\u001b[K     |███████▋                        | 10 kB 21.4 MB/s eta 0:00:01\r\u001b[K     |███████████████▎                | 20 kB 25.2 MB/s eta 0:00:01\r\u001b[K     |███████████████████████         | 30 kB 12.5 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▋ | 40 kB 9.6 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 42 kB 744 kB/s \n",
            "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from vitaldb) (2.23.0)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from vitaldb) (1.1.5)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from vitaldb) (1.19.5)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->vitaldb) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->vitaldb) (2018.9)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->vitaldb) (1.15.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->vitaldb) (2021.5.30)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->vitaldb) (2.10)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->vitaldb) (3.0.4)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->vitaldb) (1.24.3)\n",
            "Installing collected packages: vitaldb\n",
            "Successfully installed vitaldb-0.0.11\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZjxFdZTXBZWb"
      },
      "source": [
        "# Data loading 및 전처리"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AabTJc6dflSy"
      },
      "source": [
        "VitalDB Web API를 통해 데이터 로딩\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0b_TyfelWg6e",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3b28da82-4a5e-4428-8157-cc167a49d442"
      },
      "source": [
        "SRATE = 128  # in hz\n",
        "SEGLEN = 4 * SRATE  # samples\n",
        "BATCH_SIZE = 1024\n",
        "MAX_CASES = 100\n",
        "\n",
        "cachefile = '{}sec_{}cases.npz'.format(SEGLEN // SRATE, MAX_CASES)\n",
        "if os.path.exists(cachefile):\n",
        "    dat = np.load(cachefile)\n",
        "    x, y, b, c = dat['x'], dat['y'], dat['b'], dat['c']\n",
        "else:\n",
        "    df_trks = pd.read_csv(\"https://api.vitaldb.net/trks\")  # 트랙 정보\n",
        "    df_cases = pd.read_csv(\"https://api.vitaldb.net/cases\")  # 환자 정보\n",
        "\n",
        "    # 데이터 로딩 시 컬럼 순서\n",
        "    EEG = 0\n",
        "    SEVO = 1\n",
        "    BIS = 2\n",
        "\n",
        "    # inclusion & exclusion criteria\n",
        "    caseids = set(df_cases.loc[df_cases['age'] > 18, 'caseid']) &\\\n",
        "        set(df_trks.loc[df_trks['tname'] == 'BIS/EEG1_WAV', 'caseid']) &\\\n",
        "        set(df_trks.loc[df_trks['tname'] == 'BIS/BIS', 'caseid']) &\\\n",
        "        set(df_trks.loc[df_trks['tname'] == 'Primus/EXP_SEVO', 'caseid'])\n",
        "\n",
        "    x = []  \n",
        "    y = []  # sevo\n",
        "    b = []  # bis\n",
        "    c = []  # caseids\n",
        "    icase = 0  # 현재까지 로딩된 케이스 수\n",
        "    for caseid in caseids:\n",
        "        print('loading {} ({}/{})'.format(caseid, icase, MAX_CASES), end='...', flush=True)\n",
        "\n",
        "        # 아래 값들이 있으면 제외\n",
        "        if np.any(vitaldb.load_case(caseid, 'Orchestra/PPF20_CE') > 0.2):\n",
        "            print('propofol')\n",
        "            continue\n",
        "        if np.any(vitaldb.load_case(caseid, 'Primus/EXP_DES') > 1):\n",
        "            print('desflurane')\n",
        "            continue\n",
        "        if np.any(vitaldb.load_case(caseid, 'Primus/FEN2O') > 2):\n",
        "            print('n2o')\n",
        "            continue\n",
        "        if np.any(vitaldb.load_case(caseid, 'Orchestra/RFTN50_CE') > 0.2):\n",
        "            print('remifentanil')\n",
        "            continue\n",
        "\n",
        "        # extract data\n",
        "        vals = vitaldb.load_case(caseid, ['BIS/EEG1_WAV', 'Primus/EXP_SEVO', 'BIS/BIS'], 1 / SRATE)\n",
        "        if np.nanmax(vals[:, SEVO]) < 1:\n",
        "            print('all sevo <= 1')\n",
        "            continue\n",
        "\n",
        "        # convert etsevo to the age related mac\n",
        "        age = df_cases.loc[df_cases['caseid'] == caseid, 'age'].values[0]\n",
        "        vals[:, SEVO] /= 1.80 * 10 ** (-0.00269 * (age - 40))\n",
        "\n",
        "        if not np.any(vals[:, BIS] > 0):\n",
        "            print('all bis <= 0')\n",
        "            continue\n",
        "\n",
        "        # 뇌파가 잘 나와야 하기 때문에 bis가 값이 처음으로 계산되어 나온 곳 부터 시작함\n",
        "        valid_bis_idx = np.where(vals[:, BIS] > 0)[0]\n",
        "        first_bis_idx = valid_bis_idx[0]\n",
        "        last_bis_idx = valid_bis_idx[-1]\n",
        "        vals = vals[first_bis_idx:last_bis_idx + 1, :]\n",
        "\n",
        "        if len(vals) < 1800 * SRATE:  # 30분 이하인 case는 사용하지 않음\n",
        "            print('{} len < 30 min'.format(caseid))\n",
        "            continue\n",
        "\n",
        "        # MAC 값과 BIS 값은 5초까지 forward filling\n",
        "        vals[:, SEVO:] = pd.DataFrame(vals[:, SEVO:]).ffill(limit=5 * SRATE).values\n",
        "\n",
        "        # case 시작 부터 종료까지 1초 간격으로 데이터 추출하여 dataset 에 넣음\n",
        "        oldlen = len(y)\n",
        "        for irow in range(SEGLEN, len(vals), SRATE):\n",
        "            bis = vals[irow, BIS]\n",
        "            mac = vals[irow, SEVO]\n",
        "            if np.isnan(bis) or np.isnan(mac) or bis == 0:\n",
        "                continue\n",
        "            # dataset 에 추가\n",
        "            eeg = vals[irow - SEGLEN:irow, EEG]\n",
        "            x.append(eeg)\n",
        "            y.append(mac)\n",
        "            b.append(bis)\n",
        "            c.append(caseid)\n",
        "\n",
        "        # valid case\n",
        "        icase += 1\n",
        "        print('{} samples read -> total {} samples ({}/{})'.format(len(y) - oldlen, len(y), icase, MAX_CASES))\n",
        "        if icase >= MAX_CASES:\n",
        "            break\n",
        "\n",
        "    # 입력 데이터셋을 numpy array로 변경\n",
        "    x = np.array(x)\n",
        "    y = np.array(y)\n",
        "    b = np.array(b)\n",
        "    c = np.array(c)\n",
        "\n",
        "    # save cahce file\n",
        "    np.savez(cachefile, x=x, y=y, b=b, c=c)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading 1 (0/100)...desflurane\n",
            "loading 2 (0/100)...10381 samples read -> total 10381 samples (1/100)\n",
            "loading 3 (1/100)...propofol\n",
            "loading 4 (1/100)...14367 samples read -> total 24748 samples (2/100)\n",
            "loading 5 (2/100)...propofol\n",
            "loading 10 (2/100)...14509 samples read -> total 39257 samples (3/100)\n",
            "loading 12 (3/100)...21126 samples read -> total 60383 samples (4/100)\n",
            "loading 18 (4/100)...all bis <= 0\n",
            "loading 19 (4/100)...propofol\n",
            "loading 20 (4/100)...propofol\n",
            "loading 21 (4/100)...8087 samples read -> total 68470 samples (5/100)\n",
            "loading 24 (5/100)...3599 samples read -> total 72069 samples (6/100)\n",
            "loading 25 (6/100)...9665 samples read -> total 81734 samples (7/100)\n",
            "loading 26 (7/100)...desflurane\n",
            "loading 27 (7/100)...11541 samples read -> total 93275 samples (8/100)\n",
            "loading 30 (8/100)...propofol\n",
            "loading 33 (8/100)...2797 samples read -> total 96072 samples (9/100)\n",
            "loading 34 (9/100)...propofol\n",
            "loading 35 (9/100)...propofol\n",
            "loading 38 (9/100)...propofol\n",
            "loading 43 (9/100)...9689 samples read -> total 105761 samples (10/100)\n",
            "loading 44 (10/100)...propofol\n",
            "loading 49 (10/100)...6768 samples read -> total 112529 samples (11/100)\n",
            "loading 52 (11/100)...propofol\n",
            "loading 54 (11/100)...propofol\n",
            "loading 56 (11/100)...19704 samples read -> total 132233 samples (12/100)\n",
            "loading 57 (12/100)...propofol\n",
            "loading 58 (12/100)...9825 samples read -> total 142058 samples (13/100)\n",
            "loading 60 (13/100)...propofol\n",
            "loading 61 (13/100)...5851 samples read -> total 147909 samples (14/100)\n",
            "loading 62 (14/100)...6192 samples read -> total 154101 samples (15/100)\n",
            "loading 64 (15/100)...8523 samples read -> total 162624 samples (16/100)\n",
            "loading 65 (16/100)...8276 samples read -> total 170900 samples (17/100)\n",
            "loading 66 (17/100)...6016 samples read -> total 176916 samples (18/100)\n",
            "loading 67 (18/100)...propofol\n",
            "loading 71 (18/100)...desflurane\n",
            "loading 74 (18/100)...propofol\n",
            "loading 76 (18/100)...9278 samples read -> total 186194 samples (19/100)\n",
            "loading 79 (19/100)...desflurane\n",
            "loading 80 (19/100)...1655 samples read -> total 187849 samples (20/100)\n",
            "loading 82 (20/100)...1981 samples read -> total 189830 samples (21/100)\n",
            "loading 84 (21/100)...12047 samples read -> total 201877 samples (22/100)\n",
            "loading 85 (22/100)...desflurane\n",
            "loading 87 (22/100)...6706 samples read -> total 208583 samples (23/100)\n",
            "loading 89 (23/100)...13117 samples read -> total 221700 samples (24/100)\n",
            "loading 91 (24/100)...9129 samples read -> total 230829 samples (25/100)\n",
            "loading 92 (25/100)...2898 samples read -> total 233727 samples (26/100)\n",
            "loading 95 (26/100)...2618 samples read -> total 236345 samples (27/100)\n",
            "loading 96 (27/100)...26293 samples read -> total 262638 samples (28/100)\n",
            "loading 97 (28/100)...propofol\n",
            "loading 98 (28/100)...4291 samples read -> total 266929 samples (29/100)\n",
            "loading 100 (29/100)...2366 samples read -> total 269295 samples (30/100)\n",
            "loading 103 (30/100)...propofol\n",
            "loading 104 (30/100)...propofol\n",
            "loading 105 (30/100)...11788 samples read -> total 281083 samples (31/100)\n",
            "loading 107 (31/100)...desflurane\n",
            "loading 108 (31/100)...desflurane\n",
            "loading 110 (31/100)...4330 samples read -> total 285413 samples (32/100)\n",
            "loading 111 (32/100)...propofol\n",
            "loading 112 (32/100)...7104 samples read -> total 292517 samples (33/100)\n",
            "loading 114 (33/100)...10801 samples read -> total 303318 samples (34/100)\n",
            "loading 116 (34/100)...7543 samples read -> total 310861 samples (35/100)\n",
            "loading 117 (35/100)...7608 samples read -> total 318469 samples (36/100)\n",
            "loading 118 (36/100)...propofol\n",
            "loading 123 (36/100)...all sevo <= 1\n",
            "loading 126 (36/100)...propofol\n",
            "loading 128 (36/100)...propofol\n",
            "loading 130 (36/100)...all sevo <= 1\n",
            "loading 131 (36/100)...all sevo <= 1\n",
            "loading 132 (36/100)...propofol\n",
            "loading 133 (36/100)...6908 samples read -> total 325377 samples (37/100)\n",
            "loading 135 (37/100)...7946 samples read -> total 333323 samples (38/100)\n",
            "loading 137 (38/100)...propofol\n",
            "loading 138 (38/100)...propofol\n",
            "loading 139 (38/100)...propofol\n",
            "loading 140 (38/100)...propofol\n",
            "loading 141 (38/100)...propofol\n",
            "loading 142 (38/100)...10555 samples read -> total 343878 samples (39/100)\n",
            "loading 146 (39/100)...10773 samples read -> total 354651 samples (40/100)\n",
            "loading 148 (40/100)...8421 samples read -> total 363072 samples (41/100)\n",
            "loading 149 (41/100)...propofol\n",
            "loading 151 (41/100)...propofol\n",
            "loading 156 (41/100)...propofol\n",
            "loading 159 (41/100)...desflurane\n",
            "loading 160 (41/100)...propofol\n",
            "loading 161 (41/100)...12117 samples read -> total 375189 samples (42/100)\n",
            "loading 163 (42/100)...propofol\n",
            "loading 165 (42/100)...propofol\n",
            "loading 166 (42/100)...propofol\n",
            "loading 168 (42/100)...all sevo <= 1\n",
            "loading 172 (42/100)...propofol\n",
            "loading 173 (42/100)...desflurane\n",
            "loading 174 (42/100)...2298 samples read -> total 377487 samples (43/100)\n",
            "loading 175 (43/100)...desflurane\n",
            "loading 176 (43/100)...all sevo <= 1\n",
            "loading 177 (43/100)...propofol\n",
            "loading 178 (43/100)...8810 samples read -> total 386297 samples (44/100)\n",
            "loading 179 (44/100)...desflurane\n",
            "loading 181 (44/100)...propofol\n",
            "loading 182 (44/100)...propofol\n",
            "loading 184 (44/100)...23907 samples read -> total 410204 samples (45/100)\n",
            "loading 188 (45/100)...all sevo <= 1\n",
            "loading 190 (45/100)...10858 samples read -> total 421062 samples (46/100)\n",
            "loading 191 (46/100)...desflurane\n",
            "loading 193 (46/100)...propofol\n",
            "loading 194 (46/100)...propofol\n",
            "loading 195 (46/100)..."
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NwwLDouezDLc"
      },
      "source": [
        "## 뇌파 입력 데이터 필터링"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bHa2fte0zGEv"
      },
      "source": [
        "# 결측값이 있으면 제거\n",
        "print('invalid samples...', end='', flush=True)\n",
        "valid_mask = ~(np.max(np.isnan(x), axis=1) > 0) # nan이 있으면 제거\n",
        "valid_mask &= (np.max(x, axis=1) - np.min(x, axis=1) > 12)  # bis 임피던스 체크 eeg의 전체 range가 12 미만이면 제거\n",
        "x = x[valid_mask, :]\n",
        "y = y[valid_mask]\n",
        "b = b[valid_mask]\n",
        "c = c[valid_mask]\n",
        "print('{:.1f}% removed'.format(100*(1-np.mean(valid_mask))))\n",
        "\n",
        "# 필터링\n",
        "print('baseline drift...', end='', flush=True)\n",
        "x -= scipy.signal.savgol_filter(x, 91, 3)  # remove baseline drift\n",
        "print('removed')\n",
        "\n",
        "# noise 가 많으면 제거\n",
        "print('noisy samples...', end='', flush=True)\n",
        "valid_mask = (np.nanmax(np.abs(x), axis=1) < 100) # noisy sample \n",
        "\n",
        "x = x[valid_mask, :]  # CNN 에 넣기 위해서는 3차원이어야 한다. 마지막 차원을 추가\n",
        "y = y[valid_mask]\n",
        "b = b[valid_mask]\n",
        "c = c[valid_mask]\n",
        "print('{:.1f}% removed'.format(100*(1-np.mean(valid_mask))))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VpMowosc0Bzf"
      },
      "source": [
        "## 데이터를 학습(train)과 테스트(test)로 나누기"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AjdvjsJ10JAe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "dedf00e9-423c-4ad4-ab83-2f22b254f1be"
      },
      "source": [
        "# 최종적으로 로딩 된 caseid\n",
        "caseids = list(np.unique(c))\n",
        "random.shuffle(caseids)\n",
        "\n",
        "# case 단위로 train, test case로 나눔\n",
        "ntest = max(1, int(len(caseids) * 0.2))\n",
        "caseids_train = caseids[ntest:]\n",
        "caseids_test = caseids[:ntest]\n",
        "\n",
        "train_mask = np.isin(c, caseids_train)\n",
        "test_mask = np.isin(c, caseids_test)\n",
        "x_train = x[train_mask]\n",
        "y_train = y[train_mask]\n",
        "x_test = x[test_mask]\n",
        "y_test = y[test_mask]\n",
        "b_test = b[test_mask]\n",
        "c_test = c[test_mask]\n",
        "\n",
        "print('====================================================')\n",
        "print('total: {} cases {} samples'.format(len(caseids), len(y)))\n",
        "print('train: {} cases {} samples'.format(len(np.unique(c[train_mask])), len(y_train)))\n",
        "print('test {} cases {} samples'.format(len(np.unique(c_test)), len(y_test)))\n",
        "print('====================================================')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train: 45 cases 53222 samples, testing: 5 cases 5266 samples\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPi6FIbB0NkF"
      },
      "source": [
        "# Model building\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uytYU2Fu0rPi",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7c66ef36-c08e-4e18-8f25-6d6d3e912bef"
      },
      "source": [
        "import keras.models\n",
        "import tensorflow as tf\n",
        "from keras.models import Model\n",
        "from keras.layers import Layer, LayerNormalization, Dense, Dropout, Conv1D, MaxPooling1D, GlobalAveragePooling1D, GlobalMaxPooling1D, Input, concatenate, multiply, dot, MultiHeadAttention\n",
        "from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
        "\n",
        "# hyperparameters\n",
        "tests = {\n",
        "    \"nfilt\" : [16, 32, 64],\n",
        "    \"fnode\" : [32, 64, 128],\n",
        "    \"clayer\" : [1, 2, 3, 4],\n",
        "    \"droprate\" : [0.1, 0.2],\n",
        "    \"filtsize\" : [5, 7, 9, 11],\n",
        "    'poolsize' : [2, 4, 8],\n",
        "    \"pooltype\" : ['avg', 'max']\n",
        "}\n",
        "\n",
        "# https://keras.io/examples/nlp/text_classification_with_transformer/\n",
        "keys, values = zip(*tests.items())\n",
        "permutations_dicts = it.product(*values)\n",
        "permutations_dicts = list(permutations_dicts)\n",
        "random.shuffle(permutations_dicts)\n",
        "for nfilt, fnode, clayer, droprate, filtsize, poolsize, pooltype in permutations_dicts:\n",
        "\n",
        "    keras.backend.clear_session()\n",
        "    \n",
        "    odir = '{}cases_{}sec'.format(MAX_CASES, SEGLEN // SRATE)\n",
        "    odir += '_cnn{} filt{} size{} pool{} {} do{}'.format(clayer, nfilt, filtsize, poolsize, pooltype, droprate)\n",
        "    print(\"============================\")\n",
        "    print(odir)\n",
        "    print(\"============================\")\n",
        "\n",
        "    out = inp = Input(shape=(x_train.shape[1], 1))\n",
        "    # initial cnn layer\n",
        "    out = Conv1D(filters=nfilt, kernel_size=filtsize, padding='same')(out)\n",
        "    # conv 여러층    \n",
        "    for i in range(clayer):\n",
        "        out = Conv1D(filters=nfilt, kernel_size=filtsize, padding='same', activation='relu')(out)\n",
        "        out = MaxPooling1D(poolsize, padding='same')(out)\n",
        "    if pooltype == \"avg\":\n",
        "        out = GlobalAveragePooling1D()(out)\n",
        "    else:\n",
        "        out = GlobalMaxPooling1D()(out)\n",
        "\n",
        "    if droprate:\n",
        "        out = Dropout(droprate)(out)\n",
        "    out = Dense(fnode)(out)\n",
        "    if droprate:\n",
        "        out = Dropout(droprate)(out)\n",
        "    out = Dense(1)(out)\n",
        "\n",
        "    if not os.path.exists(odir):\n",
        "        os.mkdir(odir)\n",
        "\n",
        "    cache_path = odir + \"/weights.hdf5\"\n",
        "    model = Model(inputs=[inp], outputs=[out])\n",
        "    model.summary()\n",
        "    model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['mean_absolute_error'])\n",
        "    hist = model.fit(x_train[..., None], y_train, validation_split=0.2, epochs=10, batch_size=BATCH_SIZE,\n",
        "                    callbacks=[ModelCheckpoint(monitor='val_loss', filepath=cache_path, verbose=1, save_best_only=True),\n",
        "                               EarlyStopping(monitor='val_loss', patience=1, verbose=1, mode='auto'),\n",
        "                               ])\n",
        "\n",
        "    # prediction\n",
        "    pred_test = model.predict(x_test[..., None], batch_size=BATCH_SIZE).flatten()\n",
        "\n",
        "    # 성능을 계산하여 출력\n",
        "    test_mae = np.mean(np.abs(y_test - pred_test))\n",
        "    for caseid in np.unique(c_test):\n",
        "        case_mask = (c_test == caseid)\n",
        "        pred_test[case_mask] = scipy.signal.medfilt(pred_test[case_mask], 31)\n",
        "\n",
        "    # prediction 및 그림\n",
        "    for caseid in np.unique(c_test):\n",
        "        case_mask = (c_test == caseid)\n",
        "        case_len = np.sum(case_mask)\n",
        "        if case_len == 0:\n",
        "            continue\n",
        "\n",
        "        our_mae = np.mean(np.abs(y_test[case_mask] - pred_test[case_mask]))\n",
        "        print('Total MAE={:.4f}, CaseID {}, MAE={:.4f}'.format(test_mae, caseid, our_mae))\n",
        "\n",
        "        t = np.arange(0, case_len)\n",
        "        plt.figure(figsize=(20, 5))\n",
        "        plt.plot(t, y_test[case_mask], label='MAC')  # 측정 결과 \n",
        "        plt.plot(t, pred_test[case_mask], label='Ours ({:.4f})'.format(our_mae))\n",
        "        plt.legend(loc=\"upper left\")\n",
        "        plt.tight_layout()\n",
        "        plt.xlim([0, case_len])\n",
        "        plt.ylim([0, 2])\n",
        "        plt.show()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/100\n",
            "100/100 [==============================] - 19s 110ms/step - loss: 0.1017 - mean_absolute_percentage_error: 23.7280 - val_loss: 0.1096 - val_mean_absolute_percentage_error: 30.3796\n",
            "Epoch 2/100\n",
            "100/100 [==============================] - 10s 102ms/step - loss: 0.0933 - mean_absolute_percentage_error: 21.3791 - val_loss: 0.1070 - val_mean_absolute_percentage_error: 29.2109\n",
            "Epoch 3/100\n",
            "100/100 [==============================] - 10s 102ms/step - loss: 0.0877 - mean_absolute_percentage_error: 20.2326 - val_loss: 0.0990 - val_mean_absolute_percentage_error: 28.2095\n",
            "Epoch 4/100\n",
            "100/100 [==============================] - 10s 102ms/step - loss: 0.0830 - mean_absolute_percentage_error: 19.7110 - val_loss: 0.0982 - val_mean_absolute_percentage_error: 28.2541\n",
            "Epoch 5/100\n",
            "100/100 [==============================] - 10s 102ms/step - loss: 0.0815 - mean_absolute_percentage_error: 19.4846 - val_loss: 0.0978 - val_mean_absolute_percentage_error: 28.3904\n",
            "Epoch 6/100\n",
            "100/100 [==============================] - 10s 102ms/step - loss: 0.0809 - mean_absolute_percentage_error: 19.3659 - val_loss: 0.0948 - val_mean_absolute_percentage_error: 27.6144\n",
            "Epoch 7/100\n",
            "100/100 [==============================] - 10s 102ms/step - loss: 0.0803 - mean_absolute_percentage_error: 19.2700 - val_loss: 0.0983 - val_mean_absolute_percentage_error: 28.3905\n",
            "Epoch 8/100\n",
            "100/100 [==============================] - 10s 102ms/step - loss: 0.0804 - mean_absolute_percentage_error: 19.2611 - val_loss: 0.0950 - val_mean_absolute_percentage_error: 27.9300\n",
            "Epoch 9/100\n",
            "100/100 [==============================] - 10s 102ms/step - loss: 0.0800 - mean_absolute_percentage_error: 19.1936 - val_loss: 0.0981 - val_mean_absolute_percentage_error: 28.8082\n"
          ]
        }
      ]
    }
  ]
}